import pickle as pkl
import numpy as np
import scipy.stats as sp
import csv
import matplotlib.pyplot as plt 
import glob
import seaborn as sns

#sns.set_theme(style="whitegrid")

# One row per baseline: (run-directory baseline key, line/fill colour, legend text).
# The key is the ``<baseline>`` component of the
# ``storage-final/<task>/<baseline>-<task>-seed<N>/`` run directories.
_BASELINE_STYLES = (
    ("event_threshold", "red", "Event Thresholding"),
    ("independent_belief", "yellow", "Independent Belief Updating"),
    ("no-rm", "gray", "Recurrent PPO"),
    ("perfect_rm", "black", "Perfect RM"),
    ("rm_detector", "blue", "RMSM (Ours)"),
)

# Line / fill colour used for each baseline's learning curve.
col = {key: colour for key, colour, _ in _BASELINE_STYLES}

# Human-readable legend label for each baseline.
label = {key: text for key, _, text in _BASELINE_STYLES}

# Training budget (environment frames) per task. Learning curves stop being
# read once a logged frame count exceeds this value.
frames = {
    "Traffic-v0": 25_000_000,
    "Kitchen-v2": 10_000_000,
}

def analyze_learning_curves(task, smooth=0):
    """Plot the mean training return versus frames of every baseline on ``task``.

    For each baseline, the ``train/log.csv`` files of seeds 1-8 under
    ``storage-final/<task>/`` are read. Each seed's points are aligned by their
    position in frame order, and all seeds are truncated to the shortest run.
    The result is drawn as a mean curve with a +/- 1 standard-error band across
    seeds. Baselines are drawn in the order listed below, then the figure is
    shown.

    Args:
        task: Task name; must be a key of the module-level ``frames`` dict,
            which gives the frame budget at which reading each log stops.
        smooth: Half-width, in logged points, of a moving-average window that
            is applied to each seed's curve before the across-seed mean and
            SEM are taken. ``0`` disables smoothing.

    Raises:
        FileNotFoundError: if a seed's ``log.csv`` does not exist.
        ValueError: if no seed of a baseline yields any logged return.
    """
    # Baselines in loading and plotting order.
    baselines = ("event_threshold", "independent_belief", "no-rm", "perfect_rm", "rm_detector")

    def load_data(baseline, task=task, start_model=1, end_model=8):
        """Aggregate one baseline's seeds into ``(xs, y_means, y_errors, baseline)``."""
        # Logged frame counts are rounded to the nearest multiple of this value
        # so that rows from different seeds line up. (Presumably the number of
        # frames per update -- confirm against the training config.)
        frame_denomination = 16384
        max_frames = frames[task]

        # vals[i] collects the i-th (frame, return) point of every seed.
        vals = []
        min_len = None

        for model in range(start_model, end_model + 1):
            pattern = "storage-final/%s/%s-%s-seed%d/train/log.csv" % (task, baseline, task, model)
            files = glob.glob(pattern)
            if len(files) != 1:
                raise FileNotFoundError(
                    "expected exactly one log file matching %r, found %d" % (pattern, len(files))
                )

            returns = {}
            with open(files[0], "r", newline="") as csv_file:
                for row in csv.DictReader(csv_file, delimiter=","):
                    # A row whose value equals its column name is a repeated
                    # header line (presumably from a log appended to after a
                    # restart); skip it.
                    if row['return_mean'] == 'return_mean':
                        continue
                    frame_rounded = round(float(row['frames']) / frame_denomination) * frame_denomination
                    returns[frame_rounded] = float(row['return_mean'])
                    # The first row past the budget is still recorded, so the
                    # curve reaches (or just passes) max_frames.
                    if frame_rounded > max_frames:
                        break

            min_len = len(returns) if min_len is None else min(min_len, len(returns))
            for i, f in enumerate(sorted(returns)):
                if i >= len(vals):
                    vals.append([])
                vals[i].append((f, returns[f]))

        if not min_len:
            raise ValueError(
                "no logged returns found for baseline %r on task %r" % (baseline, task)
            )

        # Keep only points present in every seed -> shape (points, seeds, 2),
        # where the last axis is (frame, return).
        vals = np.array(vals[:min_len])
        n_points = vals.shape[0]

        xs = vals[:, :, 0].mean(axis=1)
        y_means = np.zeros(n_points)
        y_errors = np.zeros(n_points)

        for i in range(n_points):
            # Window [lo, hi] is clipped at both ends of the curve.
            lo = max(0, i - smooth)
            hi = min(i + smooth, n_points - 1)
            # Per-seed moving average, then mean / SEM across seeds.
            ys = vals[lo:hi + 1, :, 1].mean(axis=0)
            y_means[i] = ys.mean()
            y_errors[i] = sp.sem(ys)

        return xs, y_means, y_errors, baseline

    def plot_baseline(returns):
        """Draw one baseline's mean curve and its +/- SEM band."""
        xs, y_means, y_errors, baseline = returns
        plt.plot(xs, y_means, color=col[baseline], label=label[baseline], linewidth=0.5)
        plt.fill_between(xs, y_means - y_errors, y_means + y_errors, alpha=0.2, color=col[baseline])

    # Load everything first so a missing or empty run fails before any drawing.
    results = [load_data(baseline) for baseline in baselines]

    plt.title(task)
    for result in results:
        plot_baseline(result)

    legend = plt.legend()
    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` and 3.9
    # removed the old name; support both.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # Thicken the legend swatches; the plotted lines themselves are thin.
    for legobj in handles:
        legobj.set_linewidth(2.0)
    plt.show()




# analyze_learning_curves('Kitchen-v2')
# analyze_learning_curves('Traffic-v0')


def analyze_beliefs(task):
    """Bar-plot the mean belief error of each belief-tracking baseline on ``task``.

    For each baseline, the pickled per-seed sequences stored at
    ``storage-final/<task>/<baseline>-<task>-seed<N>/train/<kitchen|traffic>_belief.pkl``
    (seeds 1-8) are concatenated. The bar height is their mean and the error
    bar is their standard error of the mean. The y axis is labelled
    "Total variation distance". The number of pooled values per baseline is
    printed, then the figure is shown.

    Args:
        task: Task name. ``"Kitchen-v2"`` reads ``kitchen_belief.pkl``; any
            other value reads ``traffic_belief.pkl``.

    Raises:
        FileNotFoundError: if a seed's pickle file does not exist.
    """

    def load_data(baseline, task=task, start_model=1, end_model=8):
        """Return ``(mean, sem)`` of the belief values pooled over all seeds."""
        belief_file = "kitchen_belief" if task == "Kitchen-v2" else "traffic_belief"
        vals = []

        for model in range(start_model, end_model + 1):
            pattern = "storage-final/%s/%s-%s-seed%d/train/%s.pkl" % (task, baseline, task, model, belief_file)
            files = glob.glob(pattern)
            if len(files) != 1:
                raise FileNotFoundError(
                    "expected exactly one belief file matching %r, found %d" % (pattern, len(files))
                )
            # NOTE: pickle.load can execute arbitrary code; these files must be
            # trusted local experiment outputs.
            with open(files[0], "rb") as fh:
                vals += pkl.load(fh)

        vals = np.array(vals)
        print(len(vals))
        return vals.mean(), sp.sem(vals)

    event_threshold = load_data('event_threshold')
    independent_belief = load_data('independent_belief')
    rm_detector = load_data('rm_detector')

    baselines = ['RMSM', 'Independent Belief Updating', 'Thresholding']
    heights = [rm_detector[0], independent_belief[0], event_threshold[0]]
    errors = [rm_detector[1], independent_belief[1], event_threshold[1]]

    plt.bar(baselines, heights, yerr=errors)

    plt.title(task)
    plt.ylabel("Total variation distance")

    # No legend: the bars carry no `label`, so `plt.legend()` would only warn
    # and draw nothing. The x tick labels already name each baseline.
    plt.show()

# Produce the belief-error bar chart for each task, Kitchen first.
for _task in ('Kitchen-v2', 'Traffic-v0'):
    analyze_beliefs(_task)
del _task
